{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of Conditional Deep Convolutional GANs\n",
    "References: https://arxiv.org/pdf/1511.06434.pdf (DCGAN), https://arxiv.org/pdf/1411.1784.pdf (Conditional GAN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Uncomment and run the command below only when using Google Colab\n",
    "# !pip install torch torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import datetime\n",
    "import os, sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from matplotlib.pyplot import imshow, imsave\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "MODEL_NAME = 'Conditional-DCGAN'\n",
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def to_onehot(x, num_classes=10):\n",
    "    assert isinstance(x, int) or isinstance(x, (torch.LongTensor, torch.cuda.LongTensor))\n",
    "    if isinstance(x, int):\n",
    "        c = torch.zeros(1, num_classes).long()\n",
    "        c[0][x] = 1\n",
    "    else:\n",
    "        x = x.cpu()\n",
    "        c = torch.LongTensor(x.size(0), num_classes)\n",
    "        c.zero_()\n",
    "        c.scatter_(1, x, 1) # dim, index, src value\n",
    "    return c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_sample_image(G, n_noise=100):\n",
    "    \"\"\"\n",
    "        Generate a 10x10 grid of sample digits and return it as one image.\n",
    "\n",
    "        G: generator called as G(z, c) with noise z of shape (10, n_noise) and a\n",
    "           one-hot condition c of shape (10, 10); its output is viewed as (10, 28, 28).\n",
    "        n_noise: dimensionality of the noise vector.\n",
    "        Returns a (280, 280) numpy array; row j holds 10 samples conditioned on class j.\n",
    "        Nothing is written to disk here; the caller saves or displays the result.\n",
    "    \"\"\"\n",
    "    img = np.zeros([280, 280])\n",
    "    # Sampling needs no gradients, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for j in range(10):\n",
    "            c = torch.zeros([10, 10]).to(DEVICE)\n",
    "            c[:, j] = 1\n",
    "            z = torch.randn(10, n_noise).to(DEVICE)\n",
    "            y_hat = G(z, c).view(10, 28, 28)\n",
    "            result = y_hat.cpu().numpy()\n",
    "            # Lay the 10 samples of this class side by side to form one 28x280 row.\n",
    "            img[j*28:(j+1)*28] = np.concatenate([x for x in result], axis=-1)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    \"\"\"\n",
    "        Convolutional Discriminator for MNIST\n",
    "    \"\"\"\n",
    "    def __init__(self, in_channel=1, input_size=784, condition_size=10, num_classes=1):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.transform = nn.Sequential(\n",
    "            nn.Linear(input_size+condition_size, 784),\n",
    "            nn.LeakyReLU(0.2),\n",
    "        )\n",
    "        self.conv = nn.Sequential(\n",
    "            # 28 -> 14\n",
    "            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            # 14 -> 7\n",
    "            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            # 7 -> 4\n",
    "            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.LeakyReLU(0.2),\n",
    "            nn.AvgPool2d(4),\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            # reshape input, 128 -> 1\n",
    "            nn.Linear(128, 1),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "    \n",
    "    def forward(self, x, c=None):\n",
    "        # x: (N, 1, 28, 28), c: (N, 10)\n",
    "        x, c = x.view(x.size(0), -1), c.float() # may not need\n",
    "        v = torch.cat((x, c), 1) # v: (N, 794)\n",
    "        y_ = self.transform(v) # (N, 784)\n",
    "        y_ = y_.view(y_.shape[0], 1, 28, 28) # (N, 1, 28, 28)\n",
    "        y_ = self.conv(y_)\n",
    "        y_ = y_.view(y_.size(0), -1)\n",
    "        y_ = self.fc(y_)\n",
    "        return y_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    \"\"\"\n",
    "        Convolutional Generator for MNIST\n",
    "    \"\"\"\n",
    "    def __init__(self, input_size=100, condition_size=10):\n",
    "        super(Generator, self).__init__()\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(input_size+condition_size, 4*4*512),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.conv = nn.Sequential(\n",
    "            # input: 4 by 4, output: 7 by 7\n",
    "            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.ReLU(),\n",
    "            # input: 7 by 7, output: 14 by 14\n",
    "            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(),\n",
    "            # input: 14 by 14, output: 28 by 28\n",
    "            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),\n",
    "            nn.Tanh(),\n",
    "        )\n",
    "        \n",
    "    def forward(self, x, c):\n",
    "        # x: (N, 100), c: (N, 10)\n",
    "        x, c = x.view(x.size(0), -1), c.float() # may not need\n",
    "        v = torch.cat((x, c), 1) # v: (N, 110)\n",
    "        y_ = self.fc(v)\n",
    "        y_ = y_.view(y_.size(0), 512, 4, 4)\n",
    "        y_ = self.conv(y_) # (N, 28, 28)\n",
    "        return y_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "D = Discriminator().to(DEVICE)\n",
    "G = Generator().to(DEVICE)\n",
    "# To resume from the checkpoints written at the end of this notebook, uncomment:\n",
    "# D.load_state_dict(torch.load('D_dc.pth.tar')['state_dict'])\n",
    "# G.load_state_dict(torch.load('G_dc.pth.tar')['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.5], std=[0.5]),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# MNIST training split, downloaded to ../data/ when missing; `transform` is applied to every image.\n",
    "mnist = datasets.MNIST(\n",
    "    root='../data/',\n",
    "    train=True,\n",
    "    transform=transform,\n",
    "    download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Mini-batch size; the real/fake target tensors and the training loop are sized with it.\n",
    "batch_size = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# drop_last=True keeps every batch at exactly batch_size samples, which the\n",
    "# fixed-size target tensors (D_labels / D_fakes) rely on.\n",
    "data_loader = DataLoader(\n",
    "    dataset=mnist,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Binary cross-entropy applied to the discriminator's sigmoid output.\n",
    "criterion = nn.BCELoss()\n",
    "# One Adam optimizer per network, both with the same learning rate and betas.\n",
    "D_opt = torch.optim.Adam(D.parameters(), lr=0.0005, betas=(0.5, 0.999))\n",
    "G_opt = torch.optim.Adam(G.parameters(), lr=0.0005, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "max_epoch = 30 # number of passes over the training set; the generator needs more than 20 epochs\n",
    "step = 0 # global iteration counter, incremented once per batch\n",
    "n_critic = 1 # the generator is updated once every n_critic discriminator updates\n",
    "n_noise = 100 # length of the noise vector fed to the generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# BCE targets for the discriminator: 1 marks a sample as real, 0 marks it as fake.\n",
    "D_labels = torch.ones(batch_size, 1, device=DEVICE)\n",
    "D_fakes = torch.zeros(batch_size, 1, device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0/30, Step: 0, D Loss: 1.38820219039917, G Loss: 0.6656067967414856\n",
      "Epoch: 0/30, Step: 500, D Loss: 0.12251123040914536, G Loss: 2.9214797019958496\n",
      "Epoch: 1/30, Step: 1000, D Loss: 1.193413257598877, G Loss: 1.0453091859817505\n",
      "Epoch: 1/30, Step: 1500, D Loss: 0.9341679811477661, G Loss: 0.9606684446334839\n",
      "Epoch: 2/30, Step: 2000, D Loss: 1.1234525442123413, G Loss: 1.1498955488204956\n",
      "Epoch: 2/30, Step: 2500, D Loss: 0.6930141448974609, G Loss: 1.1742724180221558\n",
      "Epoch: 3/30, Step: 3000, D Loss: 1.3757328987121582, G Loss: 0.812011182308197\n",
      "Epoch: 3/30, Step: 3500, D Loss: 1.0801360607147217, G Loss: 0.9421672821044922\n",
      "Epoch: 4/30, Step: 4000, D Loss: 0.9499390125274658, G Loss: 0.9240808486938477\n",
      "Epoch: 4/30, Step: 4500, D Loss: 1.3822717666625977, G Loss: 1.1036415100097656\n",
      "Epoch: 5/30, Step: 5000, D Loss: 1.0304744243621826, G Loss: 1.405872106552124\n",
      "Epoch: 5/30, Step: 5500, D Loss: 0.9890630841255188, G Loss: 1.632412314414978\n",
      "Epoch: 6/30, Step: 6000, D Loss: 1.521059274673462, G Loss: 0.7919453382492065\n",
      "Epoch: 6/30, Step: 6500, D Loss: 1.140388011932373, G Loss: 0.607524037361145\n",
      "Epoch: 7/30, Step: 7000, D Loss: 1.1539878845214844, G Loss: 0.8947302103042603\n",
      "Epoch: 8/30, Step: 7500, D Loss: 1.2763843536376953, G Loss: 1.1148791313171387\n",
      "Epoch: 8/30, Step: 8000, D Loss: 1.1152751445770264, G Loss: 0.7313449382781982\n",
      "Epoch: 9/30, Step: 8500, D Loss: 1.2663581371307373, G Loss: 0.6541774272918701\n",
      "Epoch: 9/30, Step: 9000, D Loss: 1.3565425872802734, G Loss: 0.7318902611732483\n",
      "Epoch: 10/30, Step: 9500, D Loss: 1.2353792190551758, G Loss: 0.7434751987457275\n",
      "Epoch: 10/30, Step: 10000, D Loss: 1.3849306106567383, G Loss: 0.7219750285148621\n",
      "Epoch: 11/30, Step: 10500, D Loss: 1.3713154792785645, G Loss: 0.7351651191711426\n",
      "Epoch: 11/30, Step: 11000, D Loss: 1.4120464324951172, G Loss: 0.6873267889022827\n",
      "Epoch: 12/30, Step: 11500, D Loss: 1.3668935298919678, G Loss: 0.7004505395889282\n",
      "Epoch: 12/30, Step: 12000, D Loss: 1.3948414325714111, G Loss: 0.6682448387145996\n",
      "Epoch: 13/30, Step: 12500, D Loss: 1.3520605564117432, G Loss: 0.7200150489807129\n",
      "Epoch: 13/30, Step: 13000, D Loss: 1.38283371925354, G Loss: 0.6875632405281067\n",
      "Epoch: 14/30, Step: 13500, D Loss: 1.3800649642944336, G Loss: 0.6897081136703491\n",
      "Epoch: 14/30, Step: 14000, D Loss: 1.3227455615997314, G Loss: 0.7387114763259888\n",
      "Epoch: 15/30, Step: 14500, D Loss: 1.3912972211837769, G Loss: 0.7016972303390503\n",
      "Epoch: 16/30, Step: 15000, D Loss: 1.3620877265930176, G Loss: 0.7103703022003174\n",
      "Epoch: 16/30, Step: 15500, D Loss: 1.383725881576538, G Loss: 0.6901295781135559\n",
      "Epoch: 17/30, Step: 16000, D Loss: 1.3718159198760986, G Loss: 0.7076488733291626\n",
      "Epoch: 17/30, Step: 16500, D Loss: 1.3845096826553345, G Loss: 0.6948199272155762\n",
      "Epoch: 18/30, Step: 17000, D Loss: 1.3773895502090454, G Loss: 0.697970986366272\n",
      "Epoch: 18/30, Step: 17500, D Loss: 1.3563889265060425, G Loss: 0.7274670004844666\n",
      "Epoch: 19/30, Step: 18000, D Loss: 1.42695152759552, G Loss: 0.6805920600891113\n",
      "Epoch: 19/30, Step: 18500, D Loss: 1.389258623123169, G Loss: 0.7034153938293457\n",
      "Epoch: 20/30, Step: 19000, D Loss: 1.3753050565719604, G Loss: 0.7082930207252502\n",
      "Epoch: 20/30, Step: 19500, D Loss: 1.370561122894287, G Loss: 0.7599979639053345\n",
      "Epoch: 21/30, Step: 20000, D Loss: 1.391028642654419, G Loss: 0.6823822855949402\n",
      "Epoch: 21/30, Step: 20500, D Loss: 1.365208387374878, G Loss: 0.7175812721252441\n",
      "Epoch: 22/30, Step: 21000, D Loss: 1.3695178031921387, G Loss: 0.7152674198150635\n",
      "Epoch: 22/30, Step: 21500, D Loss: 1.3963592052459717, G Loss: 0.683794379234314\n",
      "Epoch: 23/30, Step: 22000, D Loss: 1.3914263248443604, G Loss: 0.7091646790504456\n",
      "Epoch: 24/30, Step: 22500, D Loss: 1.3689409494400024, G Loss: 0.6955894231796265\n",
      "Epoch: 24/30, Step: 23000, D Loss: 1.4021984338760376, G Loss: 0.6963061094284058\n",
      "Epoch: 25/30, Step: 23500, D Loss: 1.3541772365570068, G Loss: 0.7195296287536621\n",
      "Epoch: 25/30, Step: 24000, D Loss: 1.3771123886108398, G Loss: 0.7081385850906372\n",
      "Epoch: 26/30, Step: 24500, D Loss: 1.3679018020629883, G Loss: 0.7069792747497559\n",
      "Epoch: 26/30, Step: 25000, D Loss: 1.3764290809631348, G Loss: 0.6909903287887573\n",
      "Epoch: 27/30, Step: 25500, D Loss: 1.3768750429153442, G Loss: 0.699029266834259\n",
      "Epoch: 27/30, Step: 26000, D Loss: 1.3821161985397339, G Loss: 0.6960726976394653\n",
      "Epoch: 28/30, Step: 26500, D Loss: 1.3820538520812988, G Loss: 0.6842552423477173\n",
      "Epoch: 28/30, Step: 27000, D Loss: 1.3942863941192627, G Loss: 0.7066346406936646\n",
      "Epoch: 29/30, Step: 27500, D Loss: 1.3738797903060913, G Loss: 0.7084473967552185\n",
      "Epoch: 29/30, Step: 28000, D Loss: 1.4056565761566162, G Loss: 0.6934462785720825\n"
     ]
    }
   ],
   "source": [
    "# imsave does not create missing directories, so make sure the output folder exists.\n",
    "os.makedirs('samples', exist_ok=True)\n",
    "\n",
    "for epoch in range(max_epoch):\n",
    "    for idx, (images, labels) in enumerate(data_loader):\n",
    "        # Training Discriminator: real images should score 1, generated images 0.\n",
    "        x = images.to(DEVICE)\n",
    "        y = labels.view(batch_size, 1)\n",
    "        y = to_onehot(y).to(DEVICE)\n",
    "        x_outputs = D(x, y)\n",
    "        D_x_loss = criterion(x_outputs, D_labels)\n",
    "\n",
    "        z = torch.randn(batch_size, n_noise).to(DEVICE)\n",
    "        # Detach the fakes: this step only updates D, so no gradients are needed in G.\n",
    "        z_outputs = D(G(z, y).detach(), y)\n",
    "        D_z_loss = criterion(z_outputs, D_fakes)\n",
    "        D_loss = D_x_loss + D_z_loss\n",
    "\n",
    "        D.zero_grad()\n",
    "        D_loss.backward()\n",
    "        D_opt.step()\n",
    "\n",
    "        if step % n_critic == 0:\n",
    "            # Training Generator: fresh fakes should be scored as real by D.\n",
    "            z = torch.randn(batch_size, n_noise).to(DEVICE)\n",
    "            z_outputs = D(G(z, y), y)\n",
    "            G_loss = criterion(z_outputs, D_labels)\n",
    "\n",
    "            # Clear both networks: backward() below also fills D's gradients.\n",
    "            D.zero_grad()\n",
    "            G.zero_grad()\n",
    "            G_loss.backward()\n",
    "            G_opt.step()\n",
    "\n",
    "        if step % 500 == 0:\n",
    "            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))\n",
    "\n",
    "        if step % 1000 == 0:\n",
    "            # Sample in eval mode, then switch back to train mode.\n",
    "            G.eval()\n",
    "            img = get_sample_image(G, n_noise)\n",
    "            imsave('samples/{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')\n",
    "            G.train()\n",
    "        step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Switch G to evaluation mode and display a 10x10 grid of generated digits (one class per row)\n",
    "G.eval()\n",
    "imshow(get_sample_image(G, n_noise), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def save_checkpoint(state, file_name='checkpoint.pth.tar'):\n",
    "    torch.save(state, file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Save each network's weights together with its optimizer state and the next epoch index.\n",
    "# (To store the weights only, torch.save(D.state_dict(), path) is enough.)\n",
    "# Requires the training cell to have run: `epoch` is the training loop's variable.\n",
    "save_checkpoint({'epoch': epoch + 1, 'state_dict':D.state_dict(), 'optimizer' : D_opt.state_dict()}, 'D_dc.pth.tar')\n",
    "save_checkpoint({'epoch': epoch + 1, 'state_dict':G.state_dict(), 'optimizer' : G_opt.state_dict()}, 'G_dc.pth.tar')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
